{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "182e4545",
   "metadata": {},
   "source": [
    "# Set up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2f898ed",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-24T18:01:21.475916Z",
     "start_time": "2024-03-24T18:01:20.396154Z"
    }
   },
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel, whereas !pip uses\n",
    "# whichever pip executable happens to be first on the shell PATH.\n",
    "%pip install gitpython"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66345348",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-24T18:01:23.700909Z",
     "start_time": "2024-03-24T18:01:22.178013Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import git\n",
    "\n",
    "from git import Repo\n",
    "\n",
    "git_url = 'https://github.com/AdhyaSuman/GINopic'\n",
    "repo_dir = 'GINopic_local'\n",
    "\n",
    "# Clone only when the target directory is absent, so that re-running this cell\n",
    "# does not fail because the destination already exists.\n",
    "if not os.path.isdir(repo_dir):\n",
    "    Repo.clone_from(git_url, repo_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61a06f99",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-24T19:17:33.087597Z",
     "start_time": "2024-03-24T19:17:33.084656Z"
    }
   },
   "outputs": [],
   "source": [
    "# Go to the root directory of the cloned repo.\n",
    "# The explicit % prefix is used so the command does not depend on automagic.\n",
    "%cd GINopic_local"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67bc703c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-25T07:33:07.264789Z",
     "start_time": "2024-03-25T07:33:05.198638Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Editable install of the package in the current directory, into the kernel's environment\n",
    "%pip install -e ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15f38e4b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-25T07:33:08.487379Z",
     "start_time": "2024-03-25T07:33:07.287794Z"
    }
   },
   "outputs": [],
   "source": [
    "# Install DGL from the wheel index given by -f, into the kernel's environment\n",
    "%pip install dgl -f https://data.dgl.ai/wheels/cu121/repo.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87d326b2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-25T07:33:09.784508Z",
     "start_time": "2024-03-25T07:33:08.510851Z"
    }
   },
   "outputs": [],
   "source": [
    "# Install dglgo from the wheel index given by -f, into the kernel's environment\n",
    "%pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8d36960",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02943112",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-25T09:03:56.058676Z",
     "start_time": "2024-03-25T09:03:54.457866Z"
    }
   },
   "outputs": [],
   "source": [
    "from octis.dataset.dataset import Dataset\n",
    "\n",
    "#Import models:\n",
    "from octis.models.GINOPIC import GINOPIC\n",
    "\n",
    "#Import coherence metrics:\n",
    "from octis.evaluation_metrics.coherence_metrics import *\n",
    "\n",
    "#Import TD metrics:\n",
    "from octis.evaluation_metrics.diversity_metrics import *\n",
    "\n",
    "#Import classification metrics:\n",
    "from octis.evaluation_metrics.classification_metrics import *\n",
    "\n",
    "import random, torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b66bf757",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ac405c3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-25T09:03:57.347786Z",
     "start_time": "2024-03-25T09:03:57.345097Z"
    },
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "data_dir = './preprocessed_datasets'\n",
    "\n",
    "def get_dataset(dataset_name):\n",
    "    \"\"\"Return an OCTIS ``Dataset`` populated for ``dataset_name``.\n",
    "\n",
    "    '20NG' and 'BBC' are obtained with ``Dataset.fetch_dataset``; 'SO', 'Bio' and\n",
    "    'SearchSnippets' are read from the matching sub-folder of ``data_dir``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dataset_name`` is not one of the names above.\n",
    "    \"\"\"\n",
    "    # Short name used in this notebook -> name passed to fetch_dataset\n",
    "    fetched = {'20NG': '20NewsGroup', 'BBC': 'BBC_News'}\n",
    "    # Datasets stored locally under data_dir/<name>\n",
    "    local = ('SO', 'Bio', 'SearchSnippets')\n",
    "\n",
    "    data = Dataset()\n",
    "    if dataset_name in fetched:\n",
    "        data.fetch_dataset(fetched[dataset_name])\n",
    "    elif dataset_name in local:\n",
    "        data.load_custom_dataset_from_folder(data_dir + '/' + dataset_name)\n",
    "    else:\n",
    "        # ValueError is still an Exception, so existing handlers keep working.\n",
    "        raise ValueError(\n",
    "            'Unknown dataset name {!r}; expected one of {}'.format(\n",
    "                dataset_name, sorted(fetched) + sorted(local)))\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4862a392",
   "metadata": {},
   "source": [
    "# Hyperparameter Tuning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e9e14fb",
   "metadata": {},
   "source": [
    "**Skip this if you do not want to tune the hyperparameters**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de67d7f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-25T07:22:46.518630Z",
     "start_time": "2024-03-25T07:22:46.158888Z"
    }
   },
   "outputs": [],
   "source": [
    "# #Define some parameters\n",
    "# dataset_name = '20NG'\n",
    "# K=20\n",
    "# use_partitions=True\n",
    "# use_validation=False\n",
    "\n",
    "# #Get the dataset\n",
    "# data = get_dataset(dataset_name)\n",
    "\n",
    "# #Define the Model\n",
    "# model = GINOPIC(num_topics=K,\n",
    "#                 use_partitions=use_partitions,\n",
    "#                 use_validation=use_validation,\n",
    "#                 num_epochs=50,\n",
    "#                 w2v_path='./w2v/{}_part{}_valid{}/'.format(dataset_name, use_partitions, use_validation),\n",
    "#                 graph_path= './doc_graphs/{}_part{}_valid{}/'.format(dataset_name, use_partitions, use_validation))\n",
    "        \n",
    "\n",
    "# #Coherence Score:\n",
    "# npmi = Coherence(texts=data.get_corpus(), topk=10, measure='c_npmi')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66d57c89",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-25T07:22:47.294526Z",
     "start_time": "2024-03-25T07:22:47.278569Z"
    }
   },
   "outputs": [],
   "source": [
    "# #Define the Search Space\n",
    "# from skopt.space.space import Real, Categorical, Integer\n",
    "\n",
    "# search_space = {\n",
    "#     \"g_feat_size\":    Categorical({64, 128, 256, 512, 768, 1024, 2048}), \n",
    "#     \"num_gin_layers\": Categorical({2, 3}), \n",
    "#     \"num_mlp_layers\": Categorical({1, 2, 3, 4, 5}),\n",
    "#     \"gin_hidden_dim\": Categorical({50, 100, 200, 300}),\n",
    "#     \"gin_output_dim\": Categorical({64, 128, 256, 512, 768, 1024, 2048}),\n",
    "#     \"eps_simGraph\":   Categorical({.0, .05, .1, .2, .3, .4, .5}),\n",
    "#                }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d6d9e4a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-25T07:32:42.546236Z",
     "start_time": "2024-03-25T07:22:48.222409Z"
    }
   },
   "outputs": [],
   "source": [
    "# #Define the Optimizer\n",
    "# from octis.optimization.optimizer import Optimizer\n",
    "\n",
    "# optimizer=Optimizer()\n",
    "# optimization_result = optimizer.optimize(\n",
    "#                         model,\n",
    "#                         data,\n",
    "#                         npmi,\n",
    "#                         search_space,\n",
    "#                         number_of_call=50, \n",
    "#                         model_runs=5,\n",
    "#                         save_models=False, \n",
    "#                         early_stop=False,\n",
    "#                         early_step=10,\n",
    "#                         plot_best_seen=False,\n",
    "#                         plot_model=False,\n",
    "#                         save_path='./H_optimization/{}/K_{}/'.format(dataset_name, K))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2248fc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optimization_result.save_to_csv(\"./H_optimization/{}/K_{}/results.csv\".format(dataset_name, K))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9ab6262",
   "metadata": {},
   "source": [
    "# Run Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9100a2bb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-25T09:04:49.914574Z",
     "start_time": "2024-03-25T09:04:39.376862Z"
    },
    "code_folding": [
     10
    ]
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from random import randint\n",
    "from IPython.display import clear_output\n",
    "\n",
    "# randint needs integer bounds: a float such as 2e3 is deprecated since Python 3.10\n",
    "# and raises TypeError on Python 3.12+.\n",
    "seeds = [randint(0, 2000) for _ in range(1)]\n",
    "\n",
    "# Topic counts (K) to run for each dataset\n",
    "n_topics = {\n",
    "    '20NG': [20, 50, 100],\n",
    "    'BBC': [5, 20, 50, 100],\n",
    "    'Bio': [20, 50, 100],\n",
    "    'SO': [20, 50, 100],\n",
    "    'SearchSnippets': [8, 20, 50, 100],\n",
    "}\n",
    "\n",
    "m = 'GINopic'\n",
    "datasets = ['20NG', 'BBC', 'Bio', 'SO', 'SearchSnippets']\n",
    "\n",
    "# Per-dataset keyword arguments forwarded to GINOPIC in the loop below\n",
    "params = {\n",
    "    '20NG': {\n",
    "        'num_gin_layers': 2,\n",
    "        'g_feat_size': 2048,\n",
    "        'num_mlp_layers': 1,\n",
    "        'gin_hidden_dim': 200,\n",
    "        'gin_output_dim': 768,\n",
    "        'eps_simGraph': 0.4\n",
    "    },\n",
    "    'BBC': {\n",
    "        'num_gin_layers': 3,\n",
    "        'g_feat_size': 256,\n",
    "        'num_mlp_layers': 1,\n",
    "        'gin_hidden_dim': 50,\n",
    "        'gin_output_dim': 512,\n",
    "        'eps_simGraph': 0.3\n",
    "    },\n",
    "    'Bio': {\n",
    "        'num_gin_layers': 2,\n",
    "        'g_feat_size': 1024,\n",
    "        'num_mlp_layers': 1,\n",
    "        'gin_hidden_dim': 200,\n",
    "        'gin_output_dim': 256,\n",
    "        'eps_simGraph': 0.05\n",
    "    },\n",
    "    'SO': {\n",
    "        'num_gin_layers': 2,\n",
    "        'g_feat_size': 64,\n",
    "        'num_mlp_layers': 1,\n",
    "        'gin_hidden_dim': 300,\n",
    "        'gin_output_dim': 512,\n",
    "        'eps_simGraph': 0.1\n",
    "    },\n",
    "    'SearchSnippets': {\n",
    "        'num_gin_layers': 2,\n",
    "        'g_feat_size': 1024,\n",
    "        'num_mlp_layers': 1,\n",
    "        'gin_hidden_dim': 50,\n",
    "        'gin_output_dim': 256,\n",
    "        'eps_simGraph': 0.2\n",
    "    }\n",
    "}\n",
    "\n",
    "# One list per column; every run appends exactly one entry to each list\n",
    "results = {\n",
    "    'Dataset': [],\n",
    "    'K': [],\n",
    "    'Seed': [],\n",
    "    'Model':[],\n",
    "    'NPMI': [],\n",
    "    'CV': [],\n",
    "    'Accuracy': []\n",
    "}\n",
    "\n",
    "# Constructed here but not scored inside the loop below\n",
    "irbo = InvertedRBO(topk=10, weight=.95)\n",
    "\n",
    "partition = True\n",
    "validation = False\n",
    "\n",
    "# (dataset, K, seed, error) for every run whose accuracy could not be computed\n",
    "accuracy_failures = []\n",
    "\n",
    "for seed in seeds:\n",
    "    for d in datasets:\n",
    "        for k in n_topics[d]:\n",
    "            data = get_dataset(d)\n",
    "\n",
    "            print('Results:-\\n', results)\n",
    "\n",
    "            print(\"-\"*100)\n",
    "            print('Dataset:{},\\t Model:{},\\t K={},\\t Seed={}'.format(d, m, k, seed))\n",
    "            print(\"-\"*100)\n",
    "\n",
    "            # Seed the Python and torch RNGs before the model is built\n",
    "            random.seed(seed)\n",
    "            torch.random.manual_seed(seed)\n",
    "\n",
    "            model = GINOPIC(num_topics=k,\n",
    "                 use_partitions=partition,\n",
    "                 use_validation=validation,\n",
    "                 num_epochs=50,\n",
    "                 w2v_path='./w2v/{}_part{}_valid{}/'.format(d, partition, validation),\n",
    "                 graph_path='./doc_graphs/{}_part{}_valid{}/'.format(d, partition, validation),\n",
    "                 num_gin_layers=params[d]['num_gin_layers'],\n",
    "                 g_feat_size=params[d]['g_feat_size'],\n",
    "                 num_mlp_layers=params[d]['num_mlp_layers'],\n",
    "                 gin_hidden_dim=params[d]['gin_hidden_dim'],\n",
    "                 gin_output_dim=params[d]['gin_output_dim'],\n",
    "                 eps_simGraph=params[d]['eps_simGraph']\n",
    "                )\n",
    "\n",
    "            output = model.train_model(dataset=data)\n",
    "\n",
    "            # Drop the model and empty the CUDA cache before the metrics are computed\n",
    "            del model\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "            #Hyperparams:\n",
    "            results['Dataset'].append(d)\n",
    "            results['Model'].append(m)\n",
    "            results['K'].append(k)\n",
    "            results['Seed'].append(seed)\n",
    "            #############\n",
    "\n",
    "            #Coherence Scores:\n",
    "            npmi = Coherence(texts=data.get_corpus(), topk=10, measure='c_npmi')\n",
    "            results['NPMI'].append(npmi.score(output))\n",
    "            del npmi\n",
    "\n",
    "            cv = Coherence(texts=data.get_corpus(), topk=10, measure='c_v')\n",
    "            results['CV'].append(cv.score(output))\n",
    "            del cv\n",
    "\n",
    "            #############\n",
    "            if partition:\n",
    "                #classification:\n",
    "                try:\n",
    "                    #Accuracy\n",
    "                    accuracy = AccuracyScore(data)\n",
    "                    results['Accuracy'].append(accuracy.score(output))\n",
    "                except Exception as err:\n",
    "                    # Best effort: keep the run with a 0.0 placeholder, but remember why\n",
    "                    # scoring failed. Catching Exception instead of using a bare except\n",
    "                    # lets a kernel interrupt stop the loop.\n",
    "                    accuracy_failures.append((d, k, seed, repr(err)))\n",
    "                    results['Accuracy'].append(0.0)\n",
    "            else:\n",
    "                results['Accuracy'].append(0.0)\n",
    "            #############\n",
    "            clear_output(wait=False)\n",
    "\n",
    "# Printed after the last clear_output so that swallowed accuracy errors stay visible\n",
    "if accuracy_failures:\n",
    "    print('Accuracy could not be computed for:', accuracy_failures)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00c2072b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
